import torch
import torch.nn as nn
import torch.nn.functional as F


class AttentionLSTM(nn.Module):
    """
    Bidirectional LSTM whose per-timestep outputs are pooled into a single
    vector via learned additive attention.

    Args:
        input_size: feature dimension of each timestep of the input.
        hidden_size: LSTM hidden size per direction (output is 2x this).
    """
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True, bidirectional=True)
        # Scores each timestep's 2*hidden_size feature vector with a scalar.
        self.attention = nn.Linear(hidden_size * 2, 1)

    def forward(self, x):
        """
        Args:
            x: (batch_size, seq_len, input_size)
        Returns:
            (batch_size, hidden_size * 2) attention-weighted sum over timesteps.
        """
        features, _ = self.lstm(x)  # (batch, seq_len, 2 * hidden_size)

        # Normalize per-step scores over the sequence dimension.
        scores = self.attention(features)           # (batch, seq_len, 1)
        weights = F.softmax(scores, dim=1)

        # Weighted sum collapses the sequence dimension.
        pooled = (weights * features).sum(dim=1)    # (batch, 2 * hidden_size)
        return pooled


class FullContextEmbedding(nn.Module):
    """
    Full Context Embedding (FCE): contextualizes each support embedding with a
    bidirectional LSTM run over the whole support set, so every embedding sees
    forward and backward context from the other support examples.

    Args:
        input_size: dimensionality of each support embedding.
        hidden_size: LSTM hidden size per direction (output is 2x this).
    """
    def __init__(self, input_size, hidden_size):
        super(FullContextEmbedding, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # Bidirectional LSTM for processing support set
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True, bidirectional=True)

    def forward(self, support_embeddings, support_labels=None, num_classes=None):
        """
        Args:
            support_embeddings: (num_support_total, embedding_dim)
            support_labels: unused; kept (with a default) for backward
                compatibility with existing callers.
            num_classes: unused; kept (with a default) for backward
                compatibility with existing callers.
        Returns:
            contextualized_embeddings: (num_support_total, hidden_size * 2)
        """
        # Treat the entire support set as one sequence: add a batch dim of 1.
        x = support_embeddings.unsqueeze(0)  # (1, num_support_total, embedding_dim)

        # Process through bidirectional LSTM
        lstm_out, _ = self.lstm(x)  # (1, num_support_total, hidden_size * 2)

        # Remove the singleton batch dimension.
        contextualized_embeddings = lstm_out.squeeze(0)  # (num_support_total, hidden_size * 2)

        return contextualized_embeddings


class CosineAttention(nn.Module):
    """
    Attention over the support set based on cosine similarity to the query,
    normalized with a softmax.

    Args:
        embedding_dim: dimensionality of the embeddings (stored for reference).
    """
    def __init__(self, embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim

    def forward(self, query_embedding, support_embeddings, support_labels):
        """
        Args:
            query_embedding: (embedding_dim,)
            support_embeddings: (num_support, embedding_dim)
            support_labels: (num_support,) — not used in the computation.
        Returns:
            attention_weights: (num_support,) softmax of cosine similarities.
        """
        # Unit-normalize both sides so the dot product is cosine similarity.
        q = F.normalize(query_embedding.unsqueeze(0), p=2, dim=1).squeeze(0)
        s = F.normalize(support_embeddings, p=2, dim=1)

        # One matrix-vector product gives all similarities at once.
        similarities = torch.mv(s, q)  # (num_support,)

        # Softmax turns similarities into a distribution over support items.
        return F.softmax(similarities, dim=0)


class DotProductAttention(nn.Module):
    """
    Attention over the support set based on raw (unscaled) dot products with
    the query, normalized with a softmax.

    Args:
        embedding_dim: dimensionality of the embeddings (stored for reference).
    """
    def __init__(self, embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim

    def forward(self, query_embedding, support_embeddings, support_labels):
        """
        Args:
            query_embedding: (embedding_dim,)
            support_embeddings: (num_support, embedding_dim)
            support_labels: (num_support,) — not used in the computation.
        Returns:
            attention_weights: (num_support,) softmax of dot-product scores.
        """
        # Matrix @ vector gives every support item's dot product with the query.
        scores = support_embeddings @ query_embedding  # (num_support,)

        # Softmax turns scores into a distribution over support items.
        return torch.softmax(scores, dim=0)
